#!/usr/bin/env python3
# decode_hhmm.py

import os
# Pin the numeric backends to a single thread to avoid nondeterministic
# reductions. These assignments run before numpy is imported below, and
# setdefault() keeps any value the caller has already exported.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")
# No RNG is used in this script; the intent is to stabilize any hash-based
# iteration upstream.
# NOTE(review): CPython reads PYTHONHASHSEED at interpreter startup, so setting
# it here presumably only reaches child processes. Export it before launching
# the script if this process itself needs a fixed hash seed -- TODO confirm.
os.environ.setdefault("PYTHONHASHSEED", "0")

import argparse, json, numpy as np
from hhmm_lib import (
    load_pt_records,
    build_top_sequences,
    BottomHMMParams,
    TopHMMParams,
    HHMM,
    decode_hhmm_anchored,
    CANON_TAGS,
)
from sklearn.preprocessing import StandardScaler
    # ^ attributes set manually; we do not fit here
from sklearn.decomposition import PCA

# ---------- preprocessing (scaler + optional PCA) ----------
def load_preproc(z):
    """Rebuild the preprocessing estimators from saved parameters.

    Nothing is fitted here: the learned attributes are assigned directly onto
    freshly constructed estimator instances.

    Args:
        z: Mapping-like archive (for example an opened ``.npz`` file) that
            holds ``prep_mean`` and ``prep_scale`` and, optionally, the
            ``prep_pca_*`` arrays.

    Returns:
        A ``(scaler, pca)`` pair. ``pca`` is ``None`` when the archive has no
        ``prep_pca_components`` entry or when that entry is empty.
    """
    std = StandardScaler()
    std.mean_ = z["prep_mean"]
    std.scale_ = z["prep_scale"]
    std.var_ = std.scale_ ** 2
    std.n_features_in_ = std.mean_.shape[0]

    # Guard clause: without stored components there is no PCA stage.
    has_pca = "prep_pca_components" in z and z["prep_pca_components"].size > 0
    if not has_pca:
        return std, None

    components = z["prep_pca_components"]  # [k, D_in]
    center = z["prep_pca_mean"]            # [D_in]
    n_comp = int(components.shape[0])
    n_in = int(center.shape[0])

    # Optional statistics fall back to placeholder values when absent.
    expl_var = z.get("prep_pca_explained_variance", np.ones(n_comp, float))
    expl_ratio = z.get(
        "prep_pca_explained_variance_ratio", np.ones(n_comp, float) / n_comp
    )
    sing_vals = z.get("prep_pca_singular_values", np.ones(n_comp, float))

    projector = PCA(n_components=n_comp, svd_solver="full")
    projector.components_ = components
    projector.mean_ = center
    projector.n_features_in_ = n_in
    projector.explained_variance_ = expl_var
    projector.explained_variance_ratio_ = expl_ratio
    projector.singular_values_ = sing_vals
    return std, projector

# ---------- model loader (C, K, D) ----------
def load_model(npz_path: str) -> HHMM:
    """Load an HHMM from a ``.npz`` archive.

    The archive is opened in a ``with`` block so the underlying file handle
    is released as soon as all arrays have been read.

    Args:
        npz_path: Path to the archive. It must contain ``C``, ``K``, ``D``,
            ``top_start``, ``top_trans`` and, for every ``c`` in
            ``range(C)``, the entries ``b{c}_start``, ``b{c}_trans``,
            ``b{c}_means`` and ``b{c}_vars``.

    Returns:
        ``HHMM(C, K, D, top, bottom)`` built from the stored arrays.

    Raises:
        KeyError: If a required entry is missing from the archive.
    """
    def _as_int(arr):
        # Accept both 1-element arrays and 0-d scalars for the size fields.
        return int(np.ravel(arr)[0])

    with np.load(npz_path) as z:
        C = _as_int(z["C"])
        K = _as_int(z["K"])
        D = _as_int(z["D"])
        top = TopHMMParams(z["top_start"], z["top_trans"])
        bottom = [
            BottomHMMParams(
                startprob=z[f"b{c}_start"],
                transmat=z[f"b{c}_trans"],
                means=z[f"b{c}_means"],
                variances=z[f"b{c}_vars"],
            )
            for c in range(C)
        ]
    return HHMM(C, K, D, top, bottom)

def _rec_sort_key(r):
    # stable ordering for deterministic output
    return (
        r.get("sample_id", None),
        r.get("i", None),
        r.get("id", None),
        (r.get("prompt", "") or "")[:32],
    )

def _seq_sort_key(s):
    return (
        s.get("sample_id", None),
        s.get("i", None),
        len(s.get("steps", [])),
    )

def _preprocess_steps(steps, scaler, pca, model_dim):
    """Scale (and optionally PCA-project) every step array of one sequence.

    Args:
        steps: Iterable of 2-D arrays whose second axis is the feature axis.
        scaler: Scaler exposing ``transform`` and ``n_features_in_``.
        pca: Optional projector exposing ``transform`` and ``n_features_in_``;
            ``None`` skips the projection.
        model_dim: Feature dimension the model expects after preprocessing.

    Returns:
        List of float64 arrays, one per input step.

    Raises:
        ValueError: If any stage sees an unexpected feature dimension.
    """
    processed = []
    for x in steps:
        # Input dim must match scaler expectation
        if x.shape[1] != scaler.n_features_in_:
            raise ValueError(f"[decode] Hidden dim {x.shape[1]} != scaler.n_features_in_ {scaler.n_features_in_}")
        X = scaler.transform(x)
        if pca is not None:
            if X.shape[1] != pca.n_features_in_:
                raise ValueError(f"[decode] Scaled dim {X.shape[1]} != pca.n_features_in_ {pca.n_features_in_}")
            X = pca.transform(X)
        if X.shape[1] != model_dim:
            raise ValueError(f"[decode] Post-preproc dim {X.shape[1]} != model.D {model_dim}")
        processed.append(X.astype(np.float64))
    return processed


def main():
    """Command-line entry point: run the anchored HHMM decode and write JSON.

    Steps: parse arguments, load and optionally subset the records, load the
    model and its preprocessing parameters, preprocess every labeled
    sequence, decode, reattach the identifying fields and dump the result to
    ``--out_json`` with sorted keys.

    Raises:
        ValueError: If the preprocessing output dimension does not match the
            model, or a step array has an unexpected feature dimension.
        RuntimeError: If no sequence carries ``--label_key``, or the decoder
            returns a different number of results than sequences given.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--in_pt", required=True)
    ap.add_argument("--model_npz", required=True)
    ap.add_argument("--out_json", default="hhmm_decoded.json")
    ap.add_argument("--label_key", default="sentences_with_labels",
                    help="PT field containing per-step labels; default: sentences_with_labels")
    ap.add_argument("--emit_strings", action="store_true",
                    help="Also output human-readable tag strings alongside category IDs.")
    ap.add_argument("--subset", choices=["all", "correct", "incorrect"], default="all",
                    help="Which data to decode")
    args = ap.parse_args()

    # ----- load & subset records (then sort deterministically) -----
    recs_all = load_pt_records(args.in_pt)
    recs_all = sorted(recs_all, key=_rec_sort_key)
    recs = recs_all if args.subset == "all" else [
        r for r in recs_all if bool(r.get("is_correct", False)) == (args.subset == "correct")
    ]
    print(f"[decode_hhmm] subset={args.subset}, num_records={len(recs)}")

    # ----- load model & preproc -----
    model = load_model(args.model_npz)
    # NOTE(review): allow_pickle=True lets a crafted archive execute arbitrary
    # code on load; confirm the prep_* arrays really need it before trusting
    # model files from outside sources.
    with np.load(args.model_npz, allow_pickle=True) as z_pre:
        scaler, pca = load_preproc(z_pre)

    # Sanity: preproc output dim must match model.D
    preproc_out_D = (pca.components_.shape[0] if pca is not None else scaler.n_features_in_)
    if preproc_out_D != model.D:
        raise ValueError(f"[decode] Model expects D={model.D}, but preproc yields D={preproc_out_D}")

    # ----- build raw sequences, sort, and preprocess -----
    raw_seqs = build_top_sequences(recs)
    raw_seqs = sorted(raw_seqs, key=_seq_sort_key)

    seqs = []
    for seq in raw_seqs:
        # Unlabeled sequences are dropped, so skip them before paying for
        # the scaler/PCA transforms.
        if args.label_key not in seq:
            continue

        out = {
            "steps": _preprocess_steps(seq["steps"], scaler, pca, model.D),
            args.label_key: seq[args.label_key],
        }
        # carry IDs forward
        for k in ("sample_id", "i", "id", "prompt"):
            if k in seq:
                out[k] = seq[k]
        seqs.append(out)

    kept = len(seqs)
    dropped = len(raw_seqs) - kept
    if kept == 0:
        raise RuntimeError(f"Anchored decode requested but no sequences have '{args.label_key}'.")
    print(f"Anchored decode: using {kept} labeled sequences, dropped {dropped} unlabeled.")

    # ----- anchored decode -----
    decoded_core = list(decode_hhmm_anchored(seqs, model, label_key=args.label_key))
    # zip() would silently truncate on a length mismatch and attach IDs to
    # the wrong decodes, so fail loudly instead.
    if len(decoded_core) != len(seqs):
        raise RuntimeError(
            f"[decode] decode_hhmm_anchored returned {len(decoded_core)} results for {len(seqs)} sequences"
        )

    # Reattach IDs to decoded output (relies on 1:1 order with seqs)
    decoded_with_ids = [
        {
            "sample_id": src.get("sample_id"),
            "i": src.get("i"),
            "id": src.get("id"),
            "prompt": src.get("prompt"),
            "best_categories": dec["best_categories"],
            "best_regimes_per_step": dec["best_regimes_per_step"],
        }
        for src, dec in zip(seqs, decoded_core)
    ]

    # ----- optional pretty tag names -----
    if args.emit_strings:
        n_tags = len(CANON_TAGS)
        decoded_with_ids = [
            {
                **dec,
                # Out-of-range IDs get a synthetic "cat_<id>" name.
                "best_categories_str": [
                    CANON_TAGS[i] if 0 <= i < n_tags else f"cat_{i}"
                    for i in dec["best_categories"]
                ],
            }
            for dec in decoded_with_ids
        ]

    # ----- write (sorted keys -> stable JSON) -----
    out = {
        "sequences": decoded_with_ids,
        "meta": {"C": model.C, "K": model.K, "D": model.D, "anchored_decode": True},
        "subset": args.subset,
    }
    with open(args.out_json, "w", encoding="utf-8") as f:
        json.dump(out, f, indent=2, sort_keys=True)
    print(f"Wrote decoded flows to {args.out_json}")

# Run the decoder only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
